# -*- coding: utf-8 -*-
# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import json
import re
from pathlib import Path
import subprocess
from typing import Optional, Tuple
import glob
from loguru import logger
import pandas as pd
from msserviceprofiler.modelevalstate.config.base_config import MINDIE_BENCHMARK_PERF_COLUMNS
from msserviceprofiler.modelevalstate.config.config import AisBenchConfig, VllmBenchmarkConfig, get_settings, \
    PerformanceIndex, OptimizerConfigField
from msserviceprofiler.modelevalstate.config.custom_command import AisBenchCommand, VllmBenchmarkCommand
from msserviceprofiler.modelevalstate.optimizer.interfaces.benchmark import BenchmarkInterface
from msserviceprofiler.msguard.security import open_s, walk_s
from msserviceprofiler.modelevalstate.optimizer.utils import backup, remove_file

  
MS_TO_S = 10 ** 3
US_TO_S = 10 ** 6


def parse_result(res):
    if isinstance(res, str):
        _res = res.strip().split()
        if len(_res) > 1:
            if _res[1].strip().lower() == "ms":
                return float(_res[0]) / MS_TO_S
            elif _res[1].strip().lower() == "us":
                return float(_res[0]) / US_TO_S
            else:
                return float(_res[0])
        return float(res)
    return res


class AisBench(BenchmarkInterface):
    def __init__(self, *args, config: Optional[AisBenchConfig] = None, **kwargs):
        if config:
            self.config = config
        else:
            settings = get_settings()
            self.config = settings.ais_bench
        super().__init__(*args, **kwargs)
        self.work_path = self.config.work_path
        self.update_command()
        self.model_config_path = self.get_models_config_path()
        with open_s(self.model_config_path, 'r', encoding='utf-8') as f:
            self.default_data = f.read()
        self.mindie_benchmark_perf_columns = [k.lower().strip() for k in MINDIE_BENCHMARK_PERF_COLUMNS]

    @property
    def num_prompts(self) -> int:
        """
        获取服务的进程名属性
        Returns:""

        """
        return self.config.command.num_prompts

    @num_prompts.setter
    def num_prompts(self, value):
        """
        获取服务的进程名属性
        Returns:""

        """
        self.config.command.num_prompts = value

    def update_command(self):
        self.command = AisBenchCommand(self.config.command).command

    def get_models_config_path(self):
        cmd = [self.command[0],
               "--models", self.config.command.models,
               "--search"]
        res = subprocess.run(cmd, text=True, capture_output=True)
        if res.returncode != 0:
            raise ValueError(f"The command {cmd} execution failed, with an exit code of {res.returncode}")
        _output = res.stdout
        if not _output:
            _output = res.stderr
        for _line in _output.split("\n"):
            if "--models" not in _line:
                continue
            _lines = _line.strip().split()
            if len(_lines) != 7:
                raise ValueError(
                    f"The expected data format is Task Type │ Task Name │ Config File Path. But get data is {_lines}")
            config_path = Path(_lines[-2].strip())
            return config_path
        raise ValueError(
            f"The expected data format is Task Type │ Task Name │ Config File Path. But get data is {_output}")

    def backup(self, del_log=True):
        backup(self.config.output_path, self.bak_path, self.__class__.__name__)
        if not del_log:
            backup(self.run_log, self.bak_path, self.__class__.__name__)

    def get_performance_metric(self, metric_name: str, algorithm: str = "average"):
        output_path = Path(self.config.output_path)
        result_files = glob.glob(f"{output_path}/*/performances/*/*.csv")
        if len(result_files) != 1:
            logger.error(f"The ais bench result for csv files are not unique, result files {result_files}; "
                         f"output path: {output_path}. please check")
        metric_name = metric_name.lower().strip()
        algorithm = algorithm.strip().lower()
        if algorithm not in self.mindie_benchmark_perf_columns:
            raise ValueError(f"The {algorithm} does not support it; "
                             f"only {self.mindie_benchmark_perf_columns} are supported.")
        for file in result_files:
            df = pd.read_csv(file)
            _all_metrics = [k.strip().lower() for k in df["Performance Parameters"].tolist()]
            if metric_name not in _all_metrics:
                continue
            _i = _all_metrics.index(metric_name)
            _columns = [k.lower().strip() for k in df.columns]
            _col_index = _columns.index(algorithm)
            _res = df.iloc[_i, _col_index]
            if not _res:
                continue
            return parse_result(_res)
        raise ValueError(f"Not Found value.  metric_name {metric_name}, algorithm: {algorithm}")

    def get_best_concurrency(self):
        output_path = Path(self.config.output_path)
        rate_files = glob.glob(f"{output_path}/*/performances/*/*dataset.json")
        for json_file in rate_files:
            with open_s(json_file, "r") as f:
                try:
                    data = json.load(f)
                except json.decoder.JSONDecodeError as e:
                    logger.error(f"{e}, file: {json_file}")
                    continue
            _concurrency = float(data["Concurrency"]["total"])
            _concurrency *= self.config.best_concurrency_coefficient
            _max_concurrency = float(data["Max Concurrency"]["total"])
            if _concurrency < self.config.best_concurrency_threshold:
                best_concurrency = self.config.best_concurrency_threshold
            else:
                best_concurrency = int(min(_concurrency, _max_concurrency))
            return best_concurrency
        raise ValueError(f"Not Found concurrency value. fiels: {rate_files}")

    def get_performance_index(self):
        output_path = Path(self.config.output_path)
        performance_index = PerformanceIndex()
        if not output_path.exists():
            logger.error(f"the output of aisbench is not find: {output_path}")
        performance_index.time_to_first_token = self.get_performance_metric(
            self.config.performance_config.time_to_first_token.metric,
            self.config.performance_config.time_to_first_token.algorithm)
        performance_index.time_per_output_token = self.get_performance_metric(
            self.config.performance_config.time_per_output_token.metric,
            self.config.performance_config.time_per_output_token.algorithm)
        rate_files = glob.glob(f"{output_path}/*/performances/*/*dataset.json")
        for json_file in rate_files:
            with open_s(json_file, "r") as f:
                try:
                    data = json.load(f)
                except json.decoder.JSONDecodeError as e:
                    logger.error(f"{e}, file: {json_file}")
                    continue
            total_requests = data["Total Requests"]["total"]
            success_req = data["Success Requests"]["total"]
            performance_index.throughput = float(data["Request Throughput"]["total"].split()[0])
            if total_requests != 0:
                performance_index.success_rate = success_req / total_requests
                output_average = data["Output Token Throughput"]["total"]
                performance_index.generate_speed = float(output_average.split()[0])
        return performance_index

    def before_run(self, run_params: Optional[Tuple[OptimizerConfigField]] = None):
        remove_file(Path(self.config.output_path))
        super().before_run(run_params)
        # 启动测试
        logger.debug("Start the aisbench test.")
        concurrency = rate = None
        for k in run_params:
            try:
                if k.name == "CONCURRENCY" and k.value:
                    concurrency = int(k.value)
                if k.name == "REQUESTRATE" and k.value:
                    rate = k.value
            except ValueError:
                logger.warning(f"the {k.name} is not number; please check: {k.value}")
                concurrency = rate = None
        with open_s(self.model_config_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        _request_rate_pattern = re.compile(r"(request_rate\s*=\s*)\d{1,10}(?:\.\d{1,30})?\s*,")
        _batch_size_pattern = re.compile(r"(batch_size\s*=\s*)\d{1,10}(?:\.\d{1,30})?\s*,")
        # 修改 request_rate 和 batch_size
        for i, line in enumerate(lines):
            if 'request_rate' in line:
                _res = _request_rate_pattern.search(lines[i])
                if _res:
                    if rate is None:
                        rate = 0
                    lines[i] = lines[i].replace(_res.group(), f"request_rate = {rate},")
            if 'batch_size' in line:
                _res = _batch_size_pattern.search(lines[i])
                if _res:
                    if concurrency is None:
                        concurrency = 1000
                    lines[i] = lines[i].replace(_res.group(), f"batch_size = {concurrency},")
 
        # 将修改后的内容写回文件
        with open_s(self.model_config_path, 'w', encoding='utf-8') as f:
            f.writelines(lines)


class VllmBenchMark(BenchmarkInterface):
    def __init__(self, config: Optional[VllmBenchmarkConfig] = None, *args, **kwargs):

        if config:
            self.config = config
        else:
            settings = get_settings()
            self.config = settings.vllm_benchmark
        super().__init__(*args, **kwargs)
        self.command = VllmBenchmarkCommand(self.config.command).command

    @property
    def num_prompts(self) -> int:
        """
        获取服务的进程名属性
        Returns:""

        """
        return self.config.command.num_prompts

    @num_prompts.setter
    def num_prompts(self, value):
        """
        获取服务的进程名属性
        Returns:""

        """
        self.config.command.num_prompts = value

    def update_command(self):
        self.command = VllmBenchmarkCommand(self.config.command).command

    def stop(self, del_log: bool = True):
        # 删除输出的文件
        output_path = Path(self.config.command.result_dir)
        remove_file(output_path)
        super().stop(del_log)

    def before_run(self, run_params: Optional[Tuple[OptimizerConfigField, ...]] = None):
        # 启动前清理输出目录 因为get_performance_index是从里面获取其中一条数据，防止获取到错误数据
        output_path = Path(self.config.command.result_dir)
        remove_file(output_path)
        super().before_run(run_params)

    def get_performance_index(self):
        output_path = Path(self.config.command.result_dir)
        performance_index = PerformanceIndex()
        for file in walk_s(output_path):
            file = Path(file)
            if not file.name.endswith(".json"):
                continue
            with open_s(file, mode='r', encoding="utf-8") as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError:
                    logger.error(f"Failed in parse vllm benchmark result. file: {file}")
                    continue

            performance_index.generate_speed = data.get("output_throughput", 0)
            performance_index.time_to_first_token = data.get("mean_ttft_ms", 0) / MS_TO_S
            performance_index.time_per_output_token = data.get("mean_tpot_ms", 0) / MS_TO_S
            num_prompts = data.get("num_prompts", 1)
            completed = data.get("completed", 0)
            performance_index.success_rate = 0
            if num_prompts > 0:
                performance_index.success_rate = completed / num_prompts
            performance_index.throughput = float(data.get("request_throughput", 3.0))
        return performance_index
